{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "15ba0126-56d3-426e-977f-aee8a94646a6",
   "metadata": {},
   "source": [
    "# Planning Pattern - ReAct Technique"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d71b47fa-4265-49ef-8e1f-cf8249e12a3f",
   "metadata": {},
   "source": [
    "source : https://github.com/neural-maze/agentic_patterns\n",
    "\n",
    "So, we've seen agents capable of reflecting and using tools to access the outside world. But ... **what about planning**,\n",
    "i.e. deciding what sequence of steps to follow to accomplish a large task?\n",
    "\n",
    "That is exactly what the Planning Pattern provides; ways for the LLM to break a task into **smaller, more easily accomplished subgoals** without losing track of the end goal.\n",
    "\n",
    "The most paradigmatic example of the planning pattern is the [**ReAct**](https://react-lm.github.io/) technique, displayed in the diagram above.\n",
    "\n",
    "In this notebook, you'll learn how this technique actually works. This is the **third lesson** of the \"Agentic Patterns from Scratch\" series. Take a look\n",
    "at the previous lessons if you haven't!\n",
    "\n",
    "* [First Lesson: The Reflection Pattern](https://github.com/neural-maze/agentic_patterns/blob/main/notebooks/reflection_pattern.ipynb)\n",
    "* [Second Lesson: The Tool Pattern](https://github.com/neural-maze/agentic_patterns/blob/main/notebooks/tool_pattern.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bd0dc83a-f11b-469d-bb2c-afbd91f39c5e",
   "metadata": {},
   "source": [
    "## Relevant imports and Groq Client"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "605420e2-4bab-4d0a-90ac-7b9a95fd9976",
   "metadata": {},
   "source": [
    "We start by importing all the libraries we'll be using in this tutorial as well as the Groq client."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "056fc0ed-7ded-490b-ae0b-e7d1fd71430c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "import math\n",
    "import json\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "from groq import Groq\n",
    "\n",
    "# Remember to load the environment variables. You should have the Groq API Key in there :)\n",
    "load_dotenv()\n",
    "\n",
    "MODEL = \"llama-3.1-70b-versatile\"\n",
    "GROQ_CLIENT = Groq()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46f4363c-b6e0-4c4a-aa08-ad243ddf7911",
   "metadata": {},
   "source": [
    "> If you are not familiar with the `tool` decorator, changes are you are missed the previous tutorial about the Tool Pattern. Check the video [here](https://www.youtube.com/watch?v=ApoDzZP8_ck&t=671s&ab_channel=TheNeuralMaze)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c890c6db",
   "metadata": {},
   "source": [
    "#### Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d753960",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "from dataclasses import dataclass\n",
    "import time\n",
    "\n",
    "from colorama import Fore\n",
    "from colorama import Style\n",
    "\n",
    "def completions_create(client, messages: list, model: str) -> str:\n",
    "    \"\"\"\n",
    "    Sends a request to the client's `completions.create` method to interact with the language model.\n",
    "\n",
    "    Args:\n",
    "        client (Groq): The Groq client object\n",
    "        messages (list[dict]): A list of message objects containing chat history for the model.\n",
    "        model (str): The model to use for generating tool calls and responses.\n",
    "\n",
    "    Returns:\n",
    "        str: The content of the model's response.\n",
    "    \"\"\"\n",
    "    response = client.chat.completions.create(messages=messages, model=model)\n",
    "    return str(response.choices[0].message.content)\n",
    "\n",
    "\n",
    "def build_prompt_structure(prompt: str, role: str, tag: str = \"\") -> dict:\n",
    "    \"\"\"\n",
    "    Builds a structured prompt that includes the role and content.\n",
    "\n",
    "    Args:\n",
    "        prompt (str): The actual content of the prompt.\n",
    "        role (str): The role of the speaker (e.g., user, assistant).\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary representing the structured prompt.\n",
    "    \"\"\"\n",
    "    if tag:\n",
    "        prompt = f\"<{tag}>{prompt}</{tag}>\"\n",
    "    return {\"role\": role, \"content\": prompt}\n",
    "\n",
    "\n",
    "def update_chat_history(history: list, msg: str, role: str):\n",
    "    \"\"\"\n",
    "    Updates the chat history by appending the latest response.\n",
    "\n",
    "    Args:\n",
    "        history (list): The list representing the current chat history.\n",
    "        msg (str): The message to append.\n",
    "        role (str): The role type (e.g. 'user', 'assistant', 'system')\n",
    "    \"\"\"\n",
    "    history.append(build_prompt_structure(prompt=msg, role=role))\n",
    "\n",
    "\n",
    "class ChatHistory(list):\n",
    "    def __init__(self, messages: list | None = None, total_length: int = -1):\n",
    "        \"\"\"Initialise the queue with a fixed total length.\n",
    "\n",
    "        Args:\n",
    "            messages (list | None): A list of initial messages\n",
    "            total_length (int): The maximum number of messages the chat history can hold.\n",
    "        \"\"\"\n",
    "        if messages is None:\n",
    "            messages = []\n",
    "\n",
    "        super().__init__(messages)\n",
    "        self.total_length = total_length\n",
    "\n",
    "    def append(self, msg: str):\n",
    "        \"\"\"Add a message to the queue.\n",
    "\n",
    "        Args:\n",
    "            msg (str): The message to be added to the queue\n",
    "        \"\"\"\n",
    "        if len(self) == self.total_length:\n",
    "            self.pop(0)\n",
    "        super().append(msg)\n",
    "\n",
    "\n",
    "class FixedFirstChatHistory(ChatHistory):\n",
    "    def __init__(self, messages: list | None = None, total_length: int = -1):\n",
    "        \"\"\"Initialise the queue with a fixed total length.\n",
    "\n",
    "        Args:\n",
    "            messages (list | None): A list of initial messages\n",
    "            total_length (int): The maximum number of messages the chat history can hold.\n",
    "        \"\"\"\n",
    "        super().__init__(messages, total_length)\n",
    "\n",
    "    def append(self, msg: str):\n",
    "        \"\"\"Add a message to the queue. The first messaage will always stay fixed.\n",
    "\n",
    "        Args:\n",
    "            msg (str): The message to be added to the queue\n",
    "        \"\"\"\n",
    "        if len(self) == self.total_length:\n",
    "            self.pop(1)\n",
    "        super().append(msg)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class TagContentResult:\n",
    "    \"\"\"\n",
    "    A data class to represent the result of extracting tag content.\n",
    "\n",
    "    Attributes:\n",
    "        content (List[str]): A list of strings containing the content found between the specified tags.\n",
    "        found (bool): A flag indicating whether any content was found for the given tag.\n",
    "    \"\"\"\n",
    "\n",
    "    content: list[str]\n",
    "    found: bool\n",
    "\n",
    "\n",
    "def extract_tag_content(text: str, tag: str) -> TagContentResult:\n",
    "    \"\"\"\n",
    "    Extracts all content enclosed by specified tags (e.g., <thought>, <response>, etc.).\n",
    "\n",
    "    Parameters:\n",
    "        text (str): The input string containing multiple potential tags.\n",
    "        tag (str): The name of the tag to search for (e.g., 'thought', 'response').\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary with the following keys:\n",
    "            - 'content' (list): A list of strings containing the content found between the specified tags.\n",
    "            - 'found' (bool): A flag indicating whether any content was found for the given tag.\n",
    "    \"\"\"\n",
    "    # Build the regex pattern dynamically to find multiple occurrences of the tag\n",
    "    tag_pattern = rf\"<{tag}>(.*?)</{tag}>\"\n",
    "\n",
    "    # Use findall to capture all content between the specified tag\n",
    "    matched_contents = re.findall(tag_pattern, text, re.DOTALL)\n",
    "\n",
    "    # Return the dataclass instance with the result\n",
    "    return TagContentResult(\n",
    "        content=[content.strip() for content in matched_contents],\n",
    "        found=bool(matched_contents),\n",
    "    )\n",
    "\n",
    "\n",
    "def fancy_print(message: str) -> None:\n",
    "    \"\"\"\n",
    "    Displays a fancy print message.\n",
    "\n",
    "    Args:\n",
    "        message (str): The message to display.\n",
    "    \"\"\"\n",
    "    print(Style.BRIGHT + Fore.CYAN + f\"\\n{'=' * 50}\")\n",
    "    print(Fore.MAGENTA + f\"{message}\")\n",
    "    print(Style.BRIGHT + Fore.CYAN + f\"{'=' * 50}\\n\")\n",
    "    time.sleep(0.5)\n",
    "\n",
    "\n",
    "def fancy_step_tracker(step: int, total_steps: int) -> None:\n",
    "    \"\"\"\n",
    "    Displays a fancy step tracker for each iteration of the generation-reflection loop.\n",
    "\n",
    "    Args:\n",
    "        step (int): The current step in the loop.\n",
    "        total_steps (int): The total number of steps in the loop.\n",
    "    \"\"\"\n",
    "    fancy_print(f\"STEP {step + 1}/{total_steps}\")\n",
    "    \n",
    "import json\n",
    "from typing import Callable\n",
    "\n",
    "\n",
    "def get_fn_signature(fn: Callable) -> dict:\n",
    "    \"\"\"\n",
    "    Generates the signature for a given function.\n",
    "\n",
    "    Args:\n",
    "        fn (Callable): The function whose signature needs to be extracted.\n",
    "\n",
    "    Returns:\n",
    "        dict: A dictionary containing the function's name, description,\n",
    "              and parameter types.\n",
    "    \"\"\"\n",
    "    fn_signature: dict = {\n",
    "        \"name\": fn.__name__,\n",
    "        \"description\": fn.__doc__,\n",
    "        \"parameters\": {\"properties\": {}},\n",
    "    }\n",
    "    schema = {\n",
    "        k: {\"type\": v.__name__} for k, v in fn.__annotations__.items() if k != \"return\"\n",
    "    }\n",
    "    fn_signature[\"parameters\"][\"properties\"] = schema\n",
    "    return fn_signature\n",
    "\n",
    "\n",
    "def validate_arguments(tool_call: dict, tool_signature: dict) -> dict:\n",
    "    \"\"\"\n",
    "    Validates and converts arguments in the input dictionary to match the expected types.\n",
    "\n",
    "    Args:\n",
    "        tool_call (dict): A dictionary containing the arguments passed to the tool.\n",
    "        tool_signature (dict): The expected function signature and parameter types.\n",
    "\n",
    "    Returns:\n",
    "        dict: The tool call dictionary with the arguments converted to the correct types if necessary.\n",
    "    \"\"\"\n",
    "    properties = tool_signature[\"parameters\"][\"properties\"]\n",
    "\n",
    "    # TODO: This is overly simplified but enough for simple Tools.\n",
    "    type_mapping = {\n",
    "        \"int\": int,\n",
    "        \"str\": str,\n",
    "        \"bool\": bool,\n",
    "        \"float\": float,\n",
    "    }\n",
    "\n",
    "    for arg_name, arg_value in tool_call[\"arguments\"].items():\n",
    "        expected_type = properties[arg_name].get(\"type\")\n",
    "\n",
    "        if not isinstance(arg_value, type_mapping[expected_type]):\n",
    "            tool_call[\"arguments\"][arg_name] = type_mapping[expected_type](arg_value)\n",
    "\n",
    "    return tool_call\n",
    "\n",
    "\n",
    "class Tool:\n",
    "    \"\"\"\n",
    "    A class representing a tool that wraps a callable and its signature.\n",
    "\n",
    "    Attributes:\n",
    "        name (str): The name of the tool (function).\n",
    "        fn (Callable): The function that the tool represents.\n",
    "        fn_signature (str): JSON string representation of the function's signature.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, name: str, fn: Callable, fn_signature: str):\n",
    "        self.name = name\n",
    "        self.fn = fn\n",
    "        self.fn_signature = fn_signature\n",
    "\n",
    "    def __str__(self):\n",
    "        return self.fn_signature\n",
    "\n",
    "    def run(self, **kwargs):\n",
    "        \"\"\"\n",
    "        Executes the tool (function) with provided arguments.\n",
    "\n",
    "        Args:\n",
    "            **kwargs: Keyword arguments passed to the function.\n",
    "\n",
    "        Returns:\n",
    "            The result of the function call.\n",
    "        \"\"\"\n",
    "        return self.fn(**kwargs)\n",
    "\n",
    "\n",
    "def tool(fn: Callable):\n",
    "    \"\"\"\n",
    "    A decorator that wraps a function into a Tool object.\n",
    "\n",
    "    Args:\n",
    "        fn (Callable): The function to be wrapped.\n",
    "\n",
    "    Returns:\n",
    "        Tool: A Tool object containing the function, its name, and its signature.\n",
    "    \"\"\"\n",
    "\n",
    "    def wrapper():\n",
    "        fn_signature = get_fn_signature(fn)\n",
    "        return Tool(\n",
    "            name=fn_signature.get(\"name\"), fn=fn, fn_signature=json.dumps(fn_signature)\n",
    "        )\n",
    "\n",
    "    return wrapper()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af6735f3-00d3-4a41-acb8-f4afb4a759d4",
   "metadata": {},
   "source": [
    "## A System Prompt for the ReAct Loop"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69b721b1-cedf-4e61-9d0b-0c090972efce",
   "metadata": {},
   "source": [
    "As we did with the Tool Pattern, we also need a System Prompt for the ReAct technique. This System Prompt is very similar, the difference is that it describes the ReAct loop, so that the LLM is aware of\n",
    "the three operations it's allowed to use:\n",
    "\n",
    "1. Thought: The LLM will think about which action to take\n",
    "2. Action: The LLM will use a Tool to \"act on the environment\"\n",
    "3. Observation: The LLM will observe the tool output and reflect on the next thing to do.\n",
    "\n",
    "Another key difference from the Tool Pattern System Prompt is that we are going to enclose all the messages with tags, like these: <thought></thought>, <observation></observation>. We could implement the ReAct logic without these tags, but I found it eeasier for the LLM to understand the instructions this way.\n",
    "\n",
    "Ok! So without further ado, there's the prompt!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87279781-38d4-45df-b8b5-e41c587ba38b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the System Prompt as a constant\n",
    "REACT_SYSTEM_PROMPT = \"\"\"\n",
    "You are a function calling AI model. You operate by running a loop with the following steps: Thought, Action, Observation.\n",
    "You are provided with function signatures within <tools></tools> XML tags.\n",
    "You may call one or more functions to assist with the user query. Don' make assumptions about what values to plug\n",
    "into functions. Pay special attention to the properties 'types'. You should use those types as in a Python dict.\n",
    "\n",
    "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n",
    "\n",
    "<tool_call>\n",
    "{\"name\": <function-name>,\"arguments\": <args-dict>, \"id\": <monotonically-increasing-id>}\n",
    "</tool_call>\n",
    "\n",
    "Here are the available tools / actions:\n",
    "\n",
    "<tools> \n",
    "%s\n",
    "</tools>\n",
    "\n",
    "Example session:\n",
    "\n",
    "<question>What's the current temperature in Madrid?</question>\n",
    "<thought>I need to get the current weather in Madrid</thought>\n",
    "<tool_call>{\"name\": \"get_current_weather\",\"arguments\": {\"location\": \"Madrid\", \"unit\": \"celsius\"}, \"id\": 0}</tool_call>\n",
    "\n",
    "You will be called again with this:\n",
    "\n",
    "<observation>{0: {\"temperature\": 25, \"unit\": \"celsius\"}}</observation>\n",
    "\n",
    "You then output:\n",
    "\n",
    "<response>The current temperature in Madrid is 25 degrees Celsius</response>\n",
    "\n",
    "Additional constraints:\n",
    "\n",
    "- If the user asks you something unrelated to any of the tools above, answer freely enclosing your answer with <response></response> tags.\n",
    "\"\"\""
   ]
  },
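  {
   "cell_type": "markdown",
   "id": "3e9b7c21-a4d0-4f6e-9b11-5c2d7e8f0a13",
   "metadata": {},
   "source": [
    "Because every message type sits inside its own tag, a reply that follows this prompt can be split into its parts with a regular expression. The snippet below is only a minimal sketch of that idea, not part of the original lesson: `sample_reply` is a hand-written stand-in for a model output, and `split_reply` is a hypothetical helper name.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import re\n",
    "\n",
    "# Hand-written stand-in for a model reply (assumption: real replies use the same tags)\n",
    "sample_reply = (\n",
    "    '<thought>I need to add the two numbers first</thought>'\n",
    "    '<tool_call>{\"name\": \"sum_two_elements\", \"arguments\": {\"a\": 1, \"b\": 2}, \"id\": 0}</tool_call>'\n",
    ")\n",
    "\n",
    "\n",
    "def split_reply(text: str) -> dict:\n",
    "    \"\"\"Hypothetical helper: collect the content of each known tag in a reply.\"\"\"\n",
    "    parts = {}\n",
    "    for tag in ('thought', 'tool_call', 'response'):\n",
    "        # Non-greedy match so several blocks with the same tag stay separate\n",
    "        parts[tag] = re.findall(rf'<{tag}>(.*?)</{tag}>', text, re.DOTALL)\n",
    "    return parts\n",
    "\n",
    "\n",
    "parts = split_reply(sample_reply)\n",
    "call = json.loads(parts['tool_call'][0])\n",
    "print(parts['thought'][0])\n",
    "print(call['name'], call['arguments'])\n",
    "print(parts['response'])  # empty list: this reply carries no final answer yet\n",
    "```\n",
    "\n",
    "An empty `response` list alongside a non-empty `tool_call` list is the signal that the loop should run a tool and continue."
   ]
  },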
  {
   "cell_type": "markdown",
   "id": "c3db5712-97dd-424c-bdf1-5da0970da03a",
   "metadata": {},
   "source": [
    "## Example step by step"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ddd00354-7ebc-44ec-9606-0a178af59b44",
   "metadata": {},
   "source": [
    "### Defining the Tools"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c6360cd-b9f2-4427-91bc-27f010147c04",
   "metadata": {},
   "source": [
    "Let's build an example that involves the use of three tools, like the following ones."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d3bec1-679d-4d80-bf41-bf54b00e985c",
   "metadata": {},
   "outputs": [],
   "source": [
    "@tool\n",
    "def sum_two_elements(a: int, b: int) -> int:\n",
    "    \"\"\"\n",
    "    Computes the sum of two integers.\n",
    "\n",
    "    Args:\n",
    "        a (int): The first integer to be summed.\n",
    "        b (int): The second integer to be summed.\n",
    "\n",
    "    Returns:\n",
    "        int: The sum of `a` and `b`.\n",
    "    \"\"\"\n",
    "    return a + b\n",
    "\n",
    "\n",
    "@tool\n",
    "def multiply_two_elements(a: int, b: int) -> int:\n",
    "    \"\"\"\n",
    "    Multiplies two integers.\n",
    "\n",
    "    Args:\n",
    "        a (int): The first integer to multiply.\n",
    "        b (int): The second integer to multiply.\n",
    "\n",
    "    Returns:\n",
    "        int: The product of `a` and `b`.\n",
    "    \"\"\"\n",
    "    return a * b\n",
    "\n",
    "@tool\n",
    "def compute_log(x: int) -> float | str:\n",
    "    \"\"\"\n",
    "    Computes the logarithm of an integer `x` with an optional base.\n",
    "\n",
    "    Args:\n",
    "        x (int): The integer value for which the logarithm is computed. Must be greater than 0.\n",
    "\n",
    "    Returns:\n",
    "        float: The logarithm of `x` to the specified `base`.\n",
    "    \"\"\"\n",
    "    if x <= 0:\n",
    "        return \"Logarithm is undefined for values less than or equal to 0.\"\n",
    "    \n",
    "    return math.log(x)\n",
    "\n",
    "\n",
    "available_tools = {\n",
    "    \"sum_two_elements\": sum_two_elements,\n",
    "    \"multiply_two_elements\": multiply_two_elements,\n",
    "    \"compute_log\": compute_log\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98c3044e-5d5c-44d0-973b-0c7c31f9724b",
   "metadata": {},
   "source": [
    "Remember that the `@tool` operator allows us to convert a Python function into a `Tool` automatically. We cana check that very easily with some of the functions above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b7e2078-9c5d-4687-8c1b-56b36a0194db",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Tool name: \", sum_two_elements.name)\n",
    "print(\"Tool signature: \", sum_two_elements.fn_signature)"
   ]
  },
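  {
   "cell_type": "markdown",
   "id": "8c5d1f42-6b7a-4e39-a0c4-2f9e3d7b6a58",
   "metadata": {},
   "source": [
    "The signature printed above comes from metadata that Python already stores on every function: its name, docstring and type annotations. As a rough, self-contained sketch of that mechanism (the helper name `describe` and the example function `subtract` are made up for illustration), the same kind of JSON can be built like this:\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "\n",
    "def describe(fn) -> str:\n",
    "    \"\"\"Hypothetical helper: build a JSON signature from a function's metadata.\"\"\"\n",
    "    properties = {\n",
    "        name: {'type': annotation.__name__}\n",
    "        for name, annotation in fn.__annotations__.items()\n",
    "        if name != 'return'  # the return annotation is not an argument\n",
    "    }\n",
    "    signature = {\n",
    "        'name': fn.__name__,\n",
    "        'description': fn.__doc__,\n",
    "        'parameters': {'properties': properties},\n",
    "    }\n",
    "    return json.dumps(signature)\n",
    "\n",
    "\n",
    "def subtract(a: int, b: int) -> int:\n",
    "    \"\"\"Subtracts b from a.\"\"\"\n",
    "    return a - b\n",
    "\n",
    "\n",
    "print(describe(subtract))\n",
    "```\n",
    "\n",
    "Note that this approach relies on each annotation being a plain class such as `int` or `str`, which is why the tools in this notebook annotate their arguments with simple types."
   ]
  },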
  {
   "cell_type": "markdown",
   "id": "292a747b-f48a-4c36-b62f-b528466c53a8",
   "metadata": {},
   "source": [
    "### Adding the Tools signature to the System Prompt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67cabcdb-4dc4-49b3-9a9a-fc79e1ad12b0",
   "metadata": {},
   "source": [
    "Now, we just concatenate the tools signature and add them to the System Prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b464ea29-400c-482a-83f8-d64fba727ff8",
   "metadata": {},
   "outputs": [],
   "source": [
    "tools_signature = sum_two_elements.fn_signature + \",\\n\" + multiply_two_elements.fn_signature + \",\\n\" + compute_log.fn_signature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "501dd7c6-3ef2-46fb-bc01-a888692b8fa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(tools_signature)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b52ea466-79c0-4898-9c0b-69a07c2ad1db",
   "metadata": {},
   "outputs": [],
   "source": [
    "REACT_SYSTEM_PROMPT = REACT_SYSTEM_PROMPT % tools_signature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40801b47-0021-4800-8364-9c515db1e419",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(REACT_SYSTEM_PROMPT)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2b1fd6c-253a-4a40-ab29-4d68fa976ac1",
   "metadata": {},
   "source": [
    "### ReAct Loop Step 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb771338-8865-4e37-8b8b-02a949f67c41",
   "metadata": {},
   "outputs": [],
   "source": [
    "USER_QUESTION = \"I want to calculate the sum of 1234 and 5678 and multiply the result by 5. Then, I want to take the logarithm of this result\"\n",
    "chat_history = [\n",
    "    {\n",
    "        \"role\": \"system\",\n",
    "        \"content\": REACT_SYSTEM_PROMPT\n",
    "    },\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": f\"<question>{USER_QUESTION}</question>\"\n",
    "    }\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61faa368-e936-4bf2-9a88-df2a35bb205e",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = GROQ_CLIENT.chat.completions.create(\n",
    "    messages=chat_history,\n",
    "    model=MODEL\n",
    ").choices[0].message.content\n",
    "\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9514f9ab-c162-44f6-bc51-247061f811de",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_history.append(\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": output\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf7ce6ab-7325-4844-b4d1-1ba590d8aba9",
   "metadata": {},
   "source": [
    "### ReAct Loop Step 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a207669-3ced-43ee-8aad-c8ebf347f797",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_call = extract_tag_content(output, tag=\"tool_call\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4614fe15-efa3-4111-b4e7-eb04e160f5f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_call"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b7a6279-8cd6-4d18-8a9d-90ea04e552ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_call = json.loads(tool_call.content[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5835594-0105-4452-9aac-7ca76fafaf6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_call"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34259246-b8b6-4ae9-aa4b-1bf705fdb3d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_result = available_tools[tool_call[\"name\"]].run(**tool_call[\"arguments\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8724515b-4ba8-46a6-83fc-be472c2f1bd0",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert tool_result == 1234 + 5678"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8196b58-708b-4c9c-a090-eb4ad213a927",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_history.append(\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": f\"<observation>{tool_result}</observation>\"\n",
    "    }\n",
    ")"
   ]
  },
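  {
   "cell_type": "markdown",
   "id": "b27e4a90-1d3c-4c85-8f6e-7a0b9c4d2e61",
   "metadata": {},
   "source": [
    "Step 2 always has the same three parts: parse the tool call, run the tool, and send the result back as an observation. The following offline sketch repeats those parts without calling the model. The names `demo_tools`, `assistant_turn` and `demo_history` are assumptions made for this illustration, and a plain function stands in for the `Tool` objects.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import re\n",
    "\n",
    "# Plain functions stand in for the Tool objects used in this notebook (assumption)\n",
    "demo_tools = {'add': lambda a, b: a + b}\n",
    "\n",
    "# Hand-written assistant turn in the format requested by the system prompt\n",
    "assistant_turn = '<tool_call>{\"name\": \"add\", \"arguments\": {\"a\": 2, \"b\": 3}, \"id\": 0}</tool_call>'\n",
    "demo_history = [{'role': 'assistant', 'content': assistant_turn}]\n",
    "\n",
    "# 1. Extract the JSON payload between the tool_call tags\n",
    "payload = re.findall(r'<tool_call>(.*?)</tool_call>', assistant_turn, re.DOTALL)[0]\n",
    "demo_call = json.loads(payload)\n",
    "\n",
    "# 2. Look up the tool by name and run it with the parsed arguments\n",
    "demo_result = demo_tools[demo_call['name']](**demo_call['arguments'])\n",
    "\n",
    "# 3. Feed the result back to the model as an observation in a user message\n",
    "demo_history.append({'role': 'user', 'content': f'<observation>{demo_result}</observation>'})\n",
    "\n",
    "print(demo_history[-1]['content'])  # <observation>5</observation>\n",
    "```\n",
    "\n",
    "The observation is sent with the `user` role because, from the model's point of view, it is new input arriving from outside."
   ]
  },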
  {
   "cell_type": "markdown",
   "id": "608684c2-18c3-4e13-885e-a8f1b5e40caa",
   "metadata": {},
   "source": [
    "### ReAct Loop Step 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4eb27636-9080-44bd-991c-5f89f96e2675",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = GROQ_CLIENT.chat.completions.create(\n",
    "    messages=chat_history,\n",
    "    model=MODEL\n",
    ").choices[0].message.content\n",
    "\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34603c55-b114-4ef7-be3a-ef0148eaddb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_history.append(\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": output\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c855578-49e3-4bff-981c-16659ec1b4a4",
   "metadata": {},
   "source": [
    "### ReAct Loop Step 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d494dab-7e1d-44e8-a91d-685731d76864",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_call = extract_tag_content(output, tag=\"tool_call\")\n",
    "tool_call = json.loads(tool_call.content[0])\n",
    "tool_result = available_tools[tool_call[\"name\"]].run(**tool_call[\"arguments\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33420c29-428e-4392-9c52-57b77d2fcc32",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0b61013e-be7a-457e-8a57-329266b29cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert tool_result == (1234 + 5678) * 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d422e9d2-773b-461a-b04b-5b0de30d59e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_history.append(\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": f\"<observation>{tool_result}</observation>\"\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b66d90d-6386-4026-8ea5-aa37a143c21c",
   "metadata": {},
   "source": [
    "### ReAct Loop Step 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c5beee0-8684-4948-8e34-084dcab98eb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = GROQ_CLIENT.chat.completions.create(\n",
    "    messages=chat_history,\n",
    "    model=MODEL\n",
    ").choices[0].message.content\n",
    "\n",
    "print(output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "297ce32b-7460-477a-bd5f-2ddd8a45bf11",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_history.append(\n",
    "    {\n",
    "        \"role\": \"assistant\",\n",
    "        \"content\": output\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c1a336d-2795-41b4-88af-6fed9cebea60",
   "metadata": {},
   "source": [
    "### ReAct Loop Step 6"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "97eeb277-1486-4d8f-8441-8adbde84389e",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_call = extract_tag_content(output, tag=\"tool_call\")\n",
    "tool_call = json.loads(tool_call.content[0])\n",
    "tool_result = available_tools[tool_call[\"name\"]].run(**tool_call[\"arguments\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d77818e5-ae29-42d4-8c6e-4527bcf56959",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "694284bb-64de-4f48-b0e2-c870b129a225",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert tool_result == math.log((1234 + 5678) * 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c26c8ae7-a6f3-4d8c-b509-fb39054684bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_history.append(\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": f\"<observation>{tool_result}</observation>\"\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72c74a9e-edbe-487a-9cda-17260006e639",
   "metadata": {},
   "source": [
    "### ReAct Loop Step 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7faf46c-2ee2-4f40-b371-7a42a8885e8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "output = GROQ_CLIENT.chat.completions.create(\n",
    "    messages=chat_history,\n",
    "    model=MODEL\n",
    ").choices[0].message.content\n",
    "\n",
    "print(output)"
   ]
  },
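  {
   "cell_type": "markdown",
   "id": "d60a3f17-92be-4b4c-a7d5-0e1c8b5f3a72",
   "metadata": {},
   "source": [
    "The manual steps above repeat one pattern: ask the model, run any requested tools, append the observation, and stop once a `<response>` tag appears. Below is a minimal offline sketch of that loop under stated assumptions: `fake_llm` replays scripted replies instead of calling Groq, plain functions replace the `Tool` objects, and `run_loop`, `find_tags` and `max_rounds` are hypothetical names chosen for this example.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import re\n",
    "\n",
    "# Plain functions stand in for the Tool objects (assumption)\n",
    "loop_tools = {'add': lambda a, b: a + b, 'double': lambda x: 2 * x}\n",
    "\n",
    "# Scripted replies stand in for the LLM so the sketch runs offline (assumption)\n",
    "scripted_replies = iter([\n",
    "    '<thought>Add first</thought><tool_call>{\"name\": \"add\", \"arguments\": {\"a\": 1, \"b\": 2}, \"id\": 0}</tool_call>',\n",
    "    '<thought>Now double it</thought><tool_call>{\"name\": \"double\", \"arguments\": {\"x\": 3}, \"id\": 1}</tool_call>',\n",
    "    '<response>The result is 6</response>',\n",
    "])\n",
    "\n",
    "\n",
    "def fake_llm(messages: list) -> str:\n",
    "    \"\"\"Hypothetical stand-in for the chat completion call.\"\"\"\n",
    "    return next(scripted_replies)\n",
    "\n",
    "\n",
    "def find_tags(text: str, tag: str) -> list:\n",
    "    \"\"\"Return the content of every <tag>...</tag> block in the text.\"\"\"\n",
    "    return re.findall(rf'<{tag}>(.*?)</{tag}>', text, re.DOTALL)\n",
    "\n",
    "\n",
    "def run_loop(question: str, max_rounds: int = 5) -> str:\n",
    "    history = [{'role': 'user', 'content': f'<question>{question}</question>'}]\n",
    "    for _ in range(max_rounds):\n",
    "        reply = fake_llm(history)\n",
    "        history.append({'role': 'assistant', 'content': reply})\n",
    "        answer = find_tags(reply, 'response')\n",
    "        if answer:  # a <response> tag ends the loop\n",
    "            return answer[0]\n",
    "        # Otherwise run every requested tool and key the results by call id\n",
    "        observations = {}\n",
    "        for raw_call in find_tags(reply, 'tool_call'):\n",
    "            call = json.loads(raw_call)\n",
    "            observations[call['id']] = loop_tools[call['name']](**call['arguments'])\n",
    "        history.append({'role': 'user', 'content': f'<observation>{observations}</observation>'})\n",
    "    return 'No final answer within the round limit'\n",
    "\n",
    "\n",
    "final_answer = run_loop('Add 1 and 2, then double the result')\n",
    "print(final_answer)  # The result is 6\n",
    "```\n",
    "\n",
    "The round limit is a design choice of this sketch: it keeps the loop from running forever if the model never produces a `<response>` tag."
   ]
  },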
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a4cd7f2-cc28-4d8a-a227-8f1346a3b7a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import re\n",
    "\n",
    "from colorama import Fore\n",
    "from dotenv import load_dotenv\n",
    "from groq import Groq\n",
    "\n",
    "load_dotenv()\n",
    "\n",
    "BASE_SYSTEM_PROMPT = \"\"\n",
    "\n",
    "\n",
    "REACT_SYSTEM_PROMPT = \"\"\"\n",
    "You operate by running a loop with the following steps: Thought, Action, Observation.\n",
    "You are provided with function signatures within <tools></tools> XML tags.\n",
    "You may call one or more functions to assist with the user query. Don' make assumptions about what values to plug\n",
    "into functions. Pay special attention to the properties 'types'. You should use those types as in a Python dict.\n",
    "\n",
    "For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:\n",
    "\n",
    "<tool_call>\n",
    "{\"name\": <function-name>,\"arguments\": <args-dict>, \"id\": <monotonically-increasing-id>}\n",
    "</tool_call>\n",
    "\n",
    "Here are the available tools / actions:\n",
    "\n",
    "<tools>\n",
    "%s\n",
    "</tools>\n",
    "\n",
    "Example session:\n",
    "\n",
    "<question>What's the current temperature in Madrid?</question>\n",
    "<thought>I need to get the current weather in Madrid</thought>\n",
    "<tool_call>{\"name\": \"get_current_weather\",\"arguments\": {\"location\": \"Madrid\", \"unit\": \"celsius\"}, \"id\": 0}</tool_call>\n",
    "\n",
    "You will be called again with this:\n",
    "\n",
    "<observation>{0: {\"temperature\": 25, \"unit\": \"celsius\"}}</observation>\n",
    "\n",
    "You then output:\n",
    "\n",
    "<response>The current temperature in Madrid is 25 degrees Celsius</response>\n",
    "\n",
    "Additional constraints:\n",
    "\n",
    "- If the user asks you something unrelated to any of the tools above, answer freely enclosing your answer with <response></response> tags.\n",
    "\"\"\"\n",
    "\n",
    "\n",
    "class ReactAgent:\n",
    "    \"\"\"\n",
    "    A class that represents an agent using the ReAct logic that interacts with tools to process\n",
    "    user inputs, make decisions, and execute tool calls. The agent can run interactive sessions,\n",
    "    collect tool signatures, and process multiple tool calls in a given round of interaction.\n",
    "\n",
    "    Attributes:\n",
    "        client (Groq): The Groq client used to handle model-based completions.\n",
    "        model (str): The name of the model used for generating responses. Default is \"llama-3.1-70b-versatile\".\n",
    "        tools (list[Tool]): A list of Tool instances available for execution.\n",
    "        tools_dict (dict): A dictionary mapping tool names to their corresponding Tool instances.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        tools: Tool | list[Tool],\n",
    "        model: str = \"llama-3.1-70b-versatile\",\n",
    "        system_prompt: str = BASE_SYSTEM_PROMPT,\n",
    "    ) -> None:\n",
    "        self.client = Groq()\n",
    "        self.model = model\n",
    "        self.system_prompt = system_prompt\n",
    "        self.tools = tools if isinstance(tools, list) else [tools]\n",
    "        self.tools_dict = {tool.name: tool for tool in self.tools}\n",
    "\n",
    "    def add_tool_signatures(self) -> str:\n",
    "        \"\"\"\n",
    "        Collects the function signatures of all available tools.\n",
    "\n",
    "        Returns:\n",
    "            str: A concatenated string of all tool function signatures in JSON format.\n",
    "        \"\"\"\n",
    "        return \"\\n\".join(tool.fn_signature for tool in self.tools)\n",
    "\n",
    "    def process_tool_calls(self, tool_calls_content: list) -> dict:\n",
    "        \"\"\"\n",
    "        Processes each tool call, validates arguments, executes the tools, and collects results.\n",
    "\n",
    "        Args:\n",
    "            tool_calls_content (list): List of strings, each representing a tool call in JSON format.\n",
    "\n",
    "        Returns:\n",
    "            dict: A dictionary where the keys are tool call IDs and values are the results from the tools.\n",
    "        \"\"\"\n",
    "        observations = {}\n",
    "        for tool_call_str in tool_calls_content:\n",
    "            tool_call = json.loads(tool_call_str)\n",
    "            tool_name = tool_call[\"name\"]\n",
    "            tool = self.tools_dict[tool_name]\n",
    "\n",
    "            print(Fore.GREEN + f\"\\nUsing Tool: {tool_name}\")\n",
    "\n",
    "            # Validate and execute the tool call\n",
    "            validated_tool_call = validate_arguments(\n",
    "                tool_call, json.loads(tool.fn_signature)\n",
    "            )\n",
    "            print(Fore.GREEN + f\"\\nTool call dict: \\n{validated_tool_call}\")\n",
    "\n",
    "            result = tool.run(**validated_tool_call[\"arguments\"])\n",
    "            print(Fore.GREEN + f\"\\nTool result: \\n{result}\")\n",
    "\n",
    "            # Store the result using the tool call ID\n",
    "            observations[validated_tool_call[\"id\"]] = result\n",
    "\n",
    "        return observations\n",
    "\n",
    "    def run(\n",
    "        self,\n",
    "        user_msg: str,\n",
    "        max_rounds: int = 10,\n",
    "    ) -> str:\n",
    "        \"\"\"\n",
    "        Executes a user interaction session, where the agent processes user input, generates responses,\n",
    "        handles tool calls, and updates chat history until a final response is ready or the maximum\n",
    "        number of rounds is reached.\n",
    "\n",
    "        Args:\n",
    "            user_msg (str): The user's input message to start the interaction.\n",
    "            max_rounds (int, optional): Maximum number of interaction rounds the agent should perform. Default is 10.\n",
    "\n",
    "        Returns:\n",
    "            str: The final response generated by the agent after processing user input and any tool calls.\n",
    "        \"\"\"\n",
    "        user_prompt = build_prompt_structure(\n",
    "            prompt=user_msg, role=\"user\", tag=\"question\"\n",
    "        )\n",
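    "        # Build the full system prompt locally so that repeated calls to run()\n",
    "        # do not keep appending the ReAct instructions to self.system_prompt.\n",
    "        system_prompt = self.system_prompt\n",
    "        if self.tools:\n",
    "            system_prompt += \"\\n\" + REACT_SYSTEM_PROMPT % self.add_tool_signatures()\n",
    "\n",
    "        chat_history = ChatHistory(\n",
    "            [\n",
    "                build_prompt_structure(\n",
    "                    prompt=system_prompt,\n",
    "                    role=\"system\",\n",
    "                ),\n",
    "                user_prompt,\n",
    "            ]\n",
    "        )\n",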
    "\n",
    "        if self.tools:\n",
    "            # Run the ReAct loop for max_rounds\n",
    "            for _ in range(max_rounds):\n",
    "\n",
    "                completion = completions_create(self.client, chat_history, self.model)\n",
    "\n",
    "                response = extract_tag_content(str(completion), \"response\")\n",
    "                if response.found:\n",
    "                    return response.content[0]\n",
    "\n",
    "                thought = extract_tag_content(str(completion), \"thought\")\n",
    "                tool_calls = extract_tag_content(str(completion), \"tool_call\")\n",
    "\n",
    "                update_chat_history(chat_history, completion, \"assistant\")\n",
    "\n",
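    "                # The model may skip the <thought> tag, so guard against an empty match\n",
    "                if thought.found:\n",
    "                    print(Fore.MAGENTA + f\"\\nThought: {thought.content[0]}\")\n",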
    "\n",
    "                if tool_calls.found:\n",
    "                    observations = self.process_tool_calls(tool_calls.content)\n",
    "                    print(Fore.BLUE + f\"\\nObservations: {observations}\")\n",
    "                    update_chat_history(chat_history, f\"{observations}\", \"user\")\n",
    "\n",
    "        # No tools, or no <response> within max_rounds: fall back to a plain completion\n",
    "        return completions_create(self.client, chat_history, self.model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2462f18-f4ed-494e-8676-454e883ecc84",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = ReactAgent(tools=[sum_two_elements, multiply_two_elements, compute_log])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bd048fc-1415-4ea1-a9b5-f488fc9a1ef2",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent.run(user_msg=\"I want to calculate the sum of 1234 and 5678 and multiply the result by 5. Then, I want to take the logarithm of this result\")"
   ]
  },
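  {
   "cell_type": "markdown",
   "id": "react-sanity-check-note",
   "metadata": {},
   "source": [
    "As a quick sanity check, the value the agent should arrive at can be computed directly. The cell below is a minimal sketch, not part of the agent: `expected_result` is a hypothetical helper introduced here for illustration, and it assumes that `compute_log` returns the natural logarithm (swap in `math.log10` if your tool uses base 10)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "react-sanity-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def expected_result(a: float, b: float, factor: float) -> float:\n",
    "    \"\"\"Hypothetical helper: reproduce the three steps the agent is asked to plan.\"\"\"\n",
    "    total = a + b  # step 1: what sum_two_elements should return\n",
    "    scaled = total * factor  # step 2: what multiply_two_elements should return\n",
    "    return math.log(scaled)  # step 3: compute_log, assumed here to be the natural log\n",
    "\n",
    "\n",
    "# Compare this value with the number in the agent's final response.\n",
    "print(expected_result(1234, 5678, 5))"
   ]
  },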
  {
   "cell_type": "markdown",
   "id": "817c78d6-760f-4161-b08b-7be9bf1fe010",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "We did it! A ReAct agent, built completely from scratch and working as expected! 🚀🚀🚀🚀"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
